import torch

from .scores import (
    hessian_importance,
    hessianfree_importance,
    llmpruner_1_importance,
    random_importance,
    wanda_importance,
    weight_importance,
)


class BasePruner:
    """Base pruner: holds pruning configuration and computes weight-importance scores.

    The base class only stores configuration and fills ``self.W_metrics`` via
    :meth:`calculate_scores`; :meth:`prune` is a no-op hook meant to be
    overridden.
    """

    # Names accepted for ``scores``; kept next to the dispatch in
    # ``calculate_scores`` so error messages can list the valid choices.
    SUPPORTED_SCORES = (
        "weight",
        "wanda",
        "hessianfree",
        "random",
        "llmpruner_1",
        "hessian",
    )

    def __init__(
        self,
        scores="weight",
        sparsity_ratio=0.0,
        n_samples=0,
        seed=0,
        dataset_name="c4",
        eval_dataset_name="c4",
        **kwargs
    ):
        """Store pruning configuration.

        Args:
            scores: Name of the importance-scoring method; one of
                ``SUPPORTED_SCORES``.
            sparsity_ratio: Target sparsity, stored for use by subclasses.
            n_samples: Number of calibration samples forwarded to the
                data-driven scorers.
            seed: Seed forwarded to the data-driven scorers.
            dataset_name: Calibration dataset name forwarded to the scorers.
            eval_dataset_name: Evaluation dataset name forwarded to the
                ``hessianfree`` and ``hessian`` scorers.
            **kwargs: Accepted and ignored by this class (presumably so a
                shared config dict can be passed through ``get``; confirm
                with callers).
        """
        self.scores = scores
        # Populated by calculate_scores(); None until then.
        self.W_metrics = None
        self.n_samples = n_samples
        self.seed = seed
        self.dataset_name = dataset_name
        self.sparsity_ratio = sparsity_ratio
        self.eval_dataset_name = eval_dataset_name

    def calculate_scores(self, model, tokenizer, device):
        """Compute importance scores with the configured method into ``self.W_metrics``.

        Args:
            model: Model to score (forwarded unchanged to the scorer).
            tokenizer: Tokenizer forwarded to the scorer.
            device: Device forwarded to the scorer.

        Raises:
            NotImplementedError: If ``self.scores`` is not one of
                ``SUPPORTED_SCORES``.
        """
        if self.scores == "weight":
            self.W_metrics = weight_importance(model, tokenizer, device)
        elif self.scores == "wanda":
            self.W_metrics = wanda_importance(
                model, tokenizer, self.n_samples, device, self.seed, self.dataset_name
            )
        elif self.scores == "hessianfree":
            self.W_metrics = hessianfree_importance(
                model,
                tokenizer,
                self.n_samples,
                device,
                self.seed,
                self.dataset_name,
                self.eval_dataset_name,
            )
        elif self.scores == "random":
            self.W_metrics = random_importance(model, tokenizer, device)
        elif self.scores == "llmpruner_1":
            self.W_metrics = llmpruner_1_importance(
                model,
                tokenizer,
                self.n_samples,
                device,
                self.seed,
                self.dataset_name,
            )
        elif self.scores == "hessian":
            # NOTE(review): positional order here is (n_samples, seed, device),
            # unlike the other calibrated scorers which receive
            # (n_samples, device, seed). Confirm against hessian_importance's
            # signature in .scores before touching this call.
            self.W_metrics = hessian_importance(
                model,
                tokenizer,
                self.n_samples,
                self.seed,
                device,
                self.dataset_name,
                self.eval_dataset_name,
            )
        else:
            raise NotImplementedError(
                f"Unknown scores method {self.scores!r}; "
                f"expected one of {', '.join(self.SUPPORTED_SCORES)}"
            )

    def prune(self, model, tokenizer, device):
        """Prune ``model``; the base implementation does nothing and returns None.

        Subclasses are expected to override this.
        """
        pass


def get(**kwargs):
    """Factory: build a ``BasePruner`` from keyword configuration.

    All keyword arguments are forwarded unchanged to ``BasePruner``.
    """
    pruner = BasePruner(**kwargs)
    return pruner
